import torch
from functools import partial


def reconstr_hook(activation, hook, sae_out):
    """Forward hook that swaps the activation for a precomputed SAE output.

    ``activation`` and ``hook`` are accepted but never read; ``sae_out`` is
    handed back unchanged (typically bound beforehand via ``functools.partial``).
    """
    replacement = sae_out
    return replacement

def zero_abl_hook(activation, hook):
    """Forward hook for zero ablation.

    Returns a new all-zero tensor matching ``activation`` (as produced by
    ``torch.zeros_like``); the input tensor is left untouched and ``hook``
    is not read.
    """
    ablated = torch.zeros_like(activation)
    return ablated

def mean_abl_hook(activation, hook):
    """Forward hook for mean ablation.

    Averages ``activation`` over its first two dimensions and broadcasts the
    result back to the full shape of ``activation``. ``hook`` is not read.
    """
    # presumably dims 0 and 1 are (batch, position) -- confirm against caller
    averaged = torch.mean(activation, dim=(0, 1))
    return averaged.expand(activation.shape)

@torch.no_grad()
def log_model_performance(step, model, activations_store, sae, batch_tokens=None):
    """Measure how patching the SAE into the model changes its loss.

    The model is evaluated four times on the same tokens: unmodified, and
    with the activation at ``sae.config["hook_point"]`` replaced by (a) the
    SAE reconstruction, (b) zeros, and (c) its mean over the first two
    dimensions. Loss-based metrics and an explained-variance ratio are
    derived from those runs.

    Args:
        step: Value stored under the ``"step"`` key of the result.
        model: Must support ``model(tokens, return_type="loss")`` and
            ``model.run_with_hooks(tokens, fwd_hooks=..., return_type="loss")``
            (presumably a TransformerLens ``HookedTransformer`` -- confirm).
        activations_store: Must provide ``get_batch_tokens()`` and
            ``get_activations(tokens)``.
        sae: Callable returning a mapping with an ``"sae_out"`` entry, and
            exposing a ``config`` mapping with ``act_size`` and
            ``hook_point`` (plus ``batch_size`` and ``seq_len`` when
            ``batch_tokens`` is not supplied).
        batch_tokens: Tokens to evaluate on. When ``None``, the first
            ``batch_size // seq_len`` rows of a fresh token batch are used.

    Returns:
        dict: ``performance/*`` metrics plus ``"step"``. A ratio whose
        denominator is exactly zero is reported as NaN instead of raising
        ``ZeroDivisionError``.
    """

    def _ratio(numerator, denominator):
        # Plain float division raises on a zero denominator; a degenerate
        # batch should yield a NaN metric rather than abort the caller.
        if denominator == 0:
            return float("nan")
        return numerator / denominator

    cfg = sae.config

    def _hooked_loss(hook_fn):
        # Loss with `hook_fn` applied at the SAE's hook point.
        return model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[(cfg["hook_point"], hook_fn)],
            return_type="loss",
        ).item()

    if batch_tokens is None:
        n_rows = cfg["batch_size"] // cfg["seq_len"]
        batch_tokens = activations_store.get_batch_tokens()[:n_rows]

    # Flatten the activations to one row per token position for the SAE.
    batch = activations_store.get_activations(batch_tokens).reshape(-1, cfg["act_size"])

    # Restore the leading two token dimensions so the reconstruction can
    # stand in for the hooked activation.
    sae_output = sae(batch)["sae_out"].reshape(
        batch_tokens.shape[0], batch_tokens.shape[1], -1
    )

    original_loss = model(batch_tokens, return_type="loss").item()
    reconstr_loss = _hooked_loss(partial(reconstr_hook, sae_out=sae_output))
    zero_loss = _hooked_loss(zero_abl_hook)
    mean_loss = _hooked_loss(mean_abl_hook)

    # Sign convention: differences are "original minus patched", so they are
    # negative whenever the patched run has the higher loss.
    ce_degradation = original_loss - reconstr_loss
    zero_degradation = original_loss - zero_loss
    mean_degradation = original_loss - mean_loss

    # Explained variance: per-feature variance summed over features, for the
    # raw activations and for the reconstruction residual.
    total_variance = torch.var(batch, dim=0).sum().item()
    reconstruction_error = batch - sae_output.reshape(-1, cfg["act_size"])
    unexplained_variance = torch.var(reconstruction_error, dim=0).sum().item()

    log_dict = {
        "performance/ce_degradation": ce_degradation,
        "performance/ce_degradation_ratio": _ratio(ce_degradation, original_loss),
        "performance/ce_ratio": _ratio(original_loss, reconstr_loss),
        "performance/recovery_from_zero": _ratio(reconstr_loss - zero_loss, zero_degradation),
        "performance/recovery_from_mean": _ratio(reconstr_loss - mean_loss, mean_degradation),
        "performance/explained_variance_ratio": _ratio(
            total_variance - unexplained_variance, total_variance
        ),
    }
    log_dict["step"] = step

    return log_dict


import torch
import torch.nn as nn
from torch.nn import functional as F


@torch.no_grad()
def log_decoder_cosine_sim_quantiles(step, sae, quantiles=[0.01, 0.05, 0.1, 0.9, 0.95, 0.99], sample_size=100000):
    """Summarize pairwise cosine similarity between the SAE's decoder rows.

    Each row of ``sae.W_dec`` is L2-normalized, the full row-by-row cosine
    similarity matrix is formed, and the identity is subtracted so that a
    row's similarity with itself contributes roughly 0 instead of 1. The
    per-row maximum and minimum of that matrix are then averaged.

    Note that because the diagonal is shifted to ~0 rather than masked out,
    it still takes part in the per-row max/min.

    Args:
        step: Value stored under the ``"step"`` key of the result.
        sae: Object exposing a 2-D ``W_dec`` tensor (one decoder vector per
            row, i.e. ``[dict_size, act_size]``).
        quantiles: Currently unused; kept so existing callers keep working.
            The default list is never mutated.
        sample_size: Currently unused; kept so existing callers keep working.

    Returns:
        dict: ``decoder_cosine_sim/mean_max``, ``decoder_cosine_sim/mean_min``
        and ``"step"``.
    """
    unit_rows = F.normalize(sae.W_dec, p=2, dim=1)
    cosine_sim_matrix = torch.mm(unit_rows, unit_rows.T)  # [dict_size, dict_size]

    # Build the identity on the same device as the weights rather than
    # assuming CUDA, so this works on CPU and on any GPU index.
    identity = torch.eye(cosine_sim_matrix.shape[0], device=cosine_sim_matrix.device)
    cosine_sim_matrix = cosine_sim_matrix - identity

    row_max = cosine_sim_matrix.max(dim=1).values
    row_min = cosine_sim_matrix.min(dim=1).values

    log_dict = {
        "decoder_cosine_sim/mean_max": row_max.mean().item(),
        "decoder_cosine_sim/mean_min": row_min.mean().item(),
    }
    log_dict["step"] = step

    return log_dict